-
Notifications
You must be signed in to change notification settings - Fork 248
feat: retry rollout if generation_logprobs contains NaN #1885
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Signed-off-by: Guyue Huang <[email protected]>
Signed-off-by: Guyue Huang <[email protected]>
📝 WalkthroughWalkthroughThe changes introduce retry logic to NeMo-Gym rollout collection to handle NaN values in generation log probabilities. A new configuration field enables retrying rollouts up to a maximum count, with NaN detection in Changes
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes 🚥 Pre-merge checks | ✅ 2 | ❌ 2❌ Failed checks (2 warnings)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
examples/nemo_gym/run_grpo_nemo_gym.py (1)
252-257:⚠️ Potential issue | 🟠 MajorAvoid code-level default for rollout_max_retries_to_avoid_lp_nan.
Using
cfg.get(..., 1)introduces a hidden default in code; please read the key directly and define the default in YAML.🔧 Suggested fix
- rollout_max_retries_to_avoid_lp_nan=policy_generation.cfg.get("rollout_max_retries_to_avoid_lp_nan", 1), + rollout_max_retries_to_avoid_lp_nan=policy_generation.cfg["rollout_max_retries_to_avoid_lp_nan"],As per coding guidelines, "YAML is the single source of truth for configuration defaults; do not set non-None defaults in code for configuration values" and "Access required config values directly (e.g.,
policy_cfg['precision']) and assume they are present; do not introduce hidden defaults in code."nemo_rl/environments/nemo_gym.py (1)
27-31:⚠️ Potential issue | 🟠 MajorRemove the TypedDict default and mark the field as required.
rollout_max_retries_to_avoid_lp_nan: int = 1sets a non-None default in code, which violates the guideline that YAML is the single source of truth for configuration defaults. In TypedDict, this pattern is also inconsistent with the codebase convention of usingNotRequired[int]for optional fields. Since the field is always explicitly provided at instantiation sites (never relying on the class default), it should be declared as required without a default value.🔧 Suggested fix
class NemoGymConfig(TypedDict): model_name: str base_urls: List[str] initial_global_config_dict: Dict[str, Any] - rollout_max_retries_to_avoid_lp_nan: int = 1 + rollout_max_retries_to_avoid_lp_nan: int
🤖 Fix all issues with AI agents
In `@nemo_rl/environments/nemo_gym.py`:
- Around line 115-152: Validate and clearly define the semantics of max_retries
before entering the loop: ensure cfg["rollout_max_retries_to_avoid_lp_nan"] is
an int >= 1 (or raise a ValueError) so nemo_gym_num_rows is always defined; keep
current semantics as "max attempts" by replacing the while trial < max_retries
loop with a for attempt in range(max_retries): or explicitly document that
max_retries is the total number of attempts, and remove the off‑by‑one
ambiguity; reference the variables max_retries, trial, the while trial <
max_retries loop, and nemo_gym_num_rows when making the check and adjustment.
In `@tests/unit/environments/test_nemo_gym.py`:
- Around line 206-208: Add a Google-style docstring to the pytest fixture
nemo_gym_with_patched_run_examples describing its purpose, parameters (if any)
and return value; place it immediately below the def
nemo_gym_with_patched_run_examples(...) line and follow Google style sections
(Args:, Returns:) and mention that it yields a nemo_gym instance with
RolloutCollectionHelper.run_examples patched for tests so readers and Sphinx can
parse it.
- Around line 247-295: The fixture currently calls context.__enter__ and yields
env but if actor creation or setup fails the function exits before calling
context.__exit__, leaking the patch; wrap the setup and yield in a try/finally
so context.__exit__ is always called: call context.__enter__ first, then create
config and env via NemoGym.options(...).remote and perform
ray.get(env.health_check.remote()) inside the try, yield env as before, and in
the finally ensure you call env.shutdown.remote() and ray.kill(env) only if env
was created, then call context.__exit__(None, None, None) to guarantee the
patch_run_examples context is reverted even on failures.
- Around line 46-47: The mutable module-global run_examples_called should follow
the project's naming and test-safety conventions: rename it to
G_RUN_EXAMPLES_CALLED (upper snake case with G_ prefix) and ensure it's reset
before each test to avoid cross-test leakage; update all references to
run_examples_called in this file (e.g., where it is incremented or asserted) to
G_RUN_EXAMPLES_CALLED and add a pytest fixture or test setup that sets
G_RUN_EXAMPLES_CALLED = 0 before each test runs.
- Around line 222-239: The patched new_run_examples is unpacking awaitables
returned by orig_run_examples and yielding tuples, causing await task to fail in
run_rollouts; instead, for each awaitable "task" from orig_run_examples(self,
examples, head_server_config) create and yield a new async wrapper coroutine (or
future) that awaits the original task, injects the NaN into
result["response"]["output"] (preserve the has_generation_log_probs check and
raise ValueError if none), and then returns the (row, result) pair—i.e., leave
orig_run_examples and run_rollouts semantics intact by yielding an awaitable
that performs the mutation after awaiting the original "task".
Signed-off-by: Guyue Huang <[email protected]>
Signed-off-by: Guyue Huang <[email protected]>
Signed-off-by: Guyue Huang <[email protected]>
|
@terrykong the pipeline has passed at commit a3cd107 can you review? |
terrykong
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
generally lgtm. small comments
| model_name=policy_generation.cfg["model_name"], | ||
| base_urls=policy_generation.dp_openai_server_base_urls, | ||
| initial_global_config_dict=config["env"]["nemo_gym"], | ||
| rollout_max_retries_to_avoid_lp_nan=policy_generation.cfg.get( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
couple of things:
- can we move the default to the yaml and avoid defaulting here
- can we add it to the generation config typeddict with docstring
- can we add asserts (==1) that say this has no effect in other places like (when nemo-gym isn't used) just so users are not under the impression that other paths are less stable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what about moving rollout_max_retries_to_avoid_lp_nan to gym env config instead of putting it at generation config? since it's only used for gym and it's implemented in gym env.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I prefe yuki's idea and I now make it as env.nemo_gym.rollout_max_attempts_to_avoid_lp_nan
nemo_rl/environments/nemo_gym.py
Outdated
| nemo_rl_result = self._postprocess_nemo_gym_to_nemo_rl_result( | ||
| nemo_gym_result, tokenizer | ||
| ) | ||
| max_retries, trial = self.cfg["rollout_max_retries_to_avoid_lp_nan"], 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also maybe
assert self.cfg["rollout_max_attempts_to_avoid_lp_nan"] >= 1, .....
just to give a nice user error instead of skipping this while loop
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also maybe good to put this assert at init part instead of here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added the assert in init
yuki-97
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm and left minor comments.
| model_name=policy_generation.cfg["model_name"], | ||
| base_urls=policy_generation.dp_openai_server_base_urls, | ||
| initial_global_config_dict=config["env"]["nemo_gym"], | ||
| rollout_max_retries_to_avoid_lp_nan=policy_generation.cfg.get( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what about moving rollout_max_retries_to_avoid_lp_nan to gym env config instead of putting it at generation config? since it's only used for gym and it's implemented in gym env.
nemo_rl/environments/nemo_gym.py
Outdated
| nemo_rl_result = self._postprocess_nemo_gym_to_nemo_rl_result( | ||
| nemo_gym_result, tokenizer | ||
| ) | ||
| max_retries, trial = self.cfg["rollout_max_retries_to_avoid_lp_nan"], 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also maybe good to put this assert at init part instead of here.
Co-authored-by: Terry Kong <[email protected]> Signed-off-by: Guyue Huang <[email protected]>
Co-authored-by: Terry Kong <[email protected]> Signed-off-by: Guyue Huang <[email protected]>
Signed-off-by: Guyue Huang <[email protected]>
Signed-off-by: Guyue Huang <[email protected]>
What does this PR do ?
In nemo_gym environment, when generation_logprobs contains NaN, retry rollout.
Issues
List issues that this PR closes (syntax):
Usage
# Add a code snippet demonstrating how to use thisBefore your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit
New Features
Tests